import math
import copy
import gym
import random
import numpy as np
import statistics
import pickle

# Import your updated custom/stochastic envs
import Continuous_CartPole      # e.g. "Continuous-CartPole-v0"
import Continuous_Pendulum      # e.g. "StochasticPendulum-v0"
import continuous_mountain_car  # e.g. "StochasticMountainCarContinuous-v0"
import continuous_acrobot       # e.g. "StochasticContinuousAcrobot-v0"
import improved_hopper
import improved_ant
import improved_walker2d

from SnapshotENV import SnapshotEnv
from hoo import HOO  # The HOO class from hoo.py

###########################################################################
# 1) environment IDs
###########################################################################
# Benchmarks executed by the main experiment, in this order.  Every id is
# passed to gym.make(); the ids are presumably registered as a side effect
# of the custom-environment imports at the top of the file.
env_names = [
    "Continuous-CartPole-v0",              # 1-D action
    "StochasticPendulum-v0",               # 1-D action
    "StochasticMountainCarContinuous-v0",  # 1-D action
    "StochasticContinuousAcrobot-v0",      # 1-D action
    "ImprovedHopper-v0",                   # 3-D action
]

###########################################################################
# 2) noise configs or constructor kwargs for each environment
###########################################################################
# Constructor kwargs forwarded to gym.make() for each environment id.
# Every entry carries the same three keys: action_noise_scale,
# dynamics_noise_scale and obs_noise_scale.  Trailing "alt:" notes keep the
# alternative values recorded next to the original settings (presumably the
# earlier / default noise levels).
ENV_NOISE_CONFIG = {
    "Continuous-CartPole-v0": dict(
        action_noise_scale=0.05,    # alt: 0.05
        dynamics_noise_scale=0.5,   # alt: 0.01
        obs_noise_scale=0.0,
    ),
    # Extra constructor kwargs (e.g. "g": 9.8 for a different gravity) can be
    # added to the pendulum entry as well.
    "StochasticPendulum-v0": dict(
        action_noise_scale=0.02,    # alt: 0.02
        dynamics_noise_scale=0.1,   # alt: 0.01
        obs_noise_scale=0.01,
    ),
    "StochasticMountainCarContinuous-v0": dict(
        action_noise_scale=0.05,    # alt: 0.03
        dynamics_noise_scale=0.5,   # alt: 0.01
        obs_noise_scale=0.0,
    ),
    "StochasticContinuousAcrobot-v0": dict(
        action_noise_scale=0.05,    # alt: 0.05
        dynamics_noise_scale=0.7,   # alt: 0.01
        obs_noise_scale=0.01,
    ),
}

# Hopper, Walker2d and Ant use one common noise profile.  A fresh dict is
# built per environment so that mutating one entry never affects another.
ENV_NOISE_CONFIG.update({
    locomotion_env: dict(
        action_noise_scale=0.03,
        dynamics_noise_scale=0.02,
        obs_noise_scale=0.01,
    )
    for locomotion_env in (
        "ImprovedHopper-v0",
        "ImprovedWalker2d-v0",
        "ImprovedAnt-v0",
    )
})

###########################################################################
# 3) Global config
###########################################################################
num_seeds = 20          # seeded repetitions per (environment, planning budget)
TEST_ITERATIONS = 150   # upper bound on executed steps in one evaluation episode
discount = 0.99         # discount applied to evaluation returns and rollouts
MAX_MCTS_DEPTH = 100    # global depth cap; bounds the random-rollout length

# Planning budgets follow a geometric progression: 16 values growing by a
# constant ratio of 1000 ** (1/15), starting at 3, each truncated to an int.
base = 1000 ** (1.0 / 15.0)
samples = list(map(lambda step: int(3 * (base ** step)), range(16)))
# Only the six smallest budgets are actually run.
samples_to_use = samples[:6]

###########################################################################
# Dimension-adaptive HOO parameters
###########################################################################
def get_hoo_params(dim):
    """Return the ``(rho, nu)`` HOO hyper-parameters for an action dimension.

    Both values follow one pattern whose constants depend on the size of
    ``dim``::

        rho = 2 ** (-exponent / dim)
        nu  = scale * dim

    ===========  ========  =====
    dim          exponent  scale
    ===========  ========  =====
    exactly 1    2         4
    up to 3      1.5       3.5
    up to 8      1         2.5
    above 8      0.7       1.8
    ===========  ========  =====

    Args:
        dim: Dimensionality of the action space.

    Returns:
        Tuple ``(rho, nu)``.
    """
    # Pick the constants for the dimension band first ...
    if dim == 1:
        exponent, scale = 2, 4
    elif dim <= 3:
        exponent, scale = 1.5, 3.5
    elif dim <= 8:
        exponent, scale = 1, 2.5
    else:
        exponent, scale = 0.7, 1.8

    # ... then apply the shared formulas once.
    return 2 ** (-exponent / dim), scale * dim

###########################################################################
# Node class for HOOT (HOO over trees) with dimension-aware parameters
###########################################################################
class Node:
    """One state of the HOOT search tree, with a HOO bandit over its actions.

    Attributes:
        parent: Node this one was expanded from (``None`` for a root).
        snapshot: Environment snapshot handed to ``env.get_result``.
        obs: Observation returned when this state was reached.
        is_done: Whether the environment reported termination here.
        children: Expanded children, keyed by ``tuple(action)`` where the
            action is the float32 array sampled from the HOO instance.
        immediate_reward: Reward received on the transition into this node.
        dim: Action dimensionality.
        hoo: ``HOO`` instance that proposes actions and receives returns.
    """

    def __init__(self, snapshot, obs, is_done, parent, dim, min_action, max_action):
        """Store the state and build a HOO instance for its action space.

        Args:
            snapshot: Environment snapshot for this state (may be ``None``;
                ``selection`` then captures one from the environment).
            obs: Observation associated with the state.
            is_done: ``True`` if the state is terminal.
            parent: Parent node or ``None``.
            dim: Action dimensionality; also selects ``rho``/``nu``.
            min_action: Lower action bound passed to HOO as ``min_value``.
            max_action: Upper action bound passed to HOO as ``max_value``.
        """
        self.parent = parent
        self.snapshot = snapshot
        self.obs = obs
        self.is_done = is_done
        self.children = {}
        self.immediate_reward = 0
        self.dim = dim

        # Dimension-adaptive HOO hyperparams
        rho, nu = get_hoo_params(dim)

        self.hoo = HOO(
            dim=dim,
            nu=nu,
            rho=rho,
            min_value=min_action,
            max_value=max_action
        )

    def selection(self, depth, env, max_depth):
        """Run one simulation from this node and return its sampled return.

        An action is drawn from the node's HOO instance.  If that exact action
        was tried before, the cached child is reused; otherwise the child is
        created from ``env.get_result``.  The recursion continues in the child
        and the resulting return is fed back into HOO.

        Args:
            depth: Depth of this node relative to the current root.
            env: Planning environment providing ``get_snapshot`` and
                ``get_result(snapshot, action)``.
            max_depth: Depth at which the simulation is cut off.

        Returns:
            Sum of rewards collected below this node (0.0 for terminal nodes
            or at the depth limit).  Rewards are summed as-is; no discount
            factor is applied inside the tree.
        """
        # If done or depth-limit, no further expansion
        if self.is_done or depth >= max_depth:
            return 0.0

        # 1) select an action from HOO and cast it to float32
        action = self.hoo.select_action().astype(np.float32)
        act_key = tuple(action)

        # 2) reuse the child for this exact action, or expand a new one
        child = self.children.get(act_key)
        if child is None:
            if self.snapshot is None:
                # Only expected for a root that was built without a snapshot:
                # fall back to the environment's current state.
                self.snapshot = env.get_snapshot()

            snapshot, obs, r, done, _ = env.get_result(self.snapshot, action)
            child = Node(
                snapshot=snapshot,
                obs=obs,
                is_done=done,
                parent=self,
                dim=self.dim,
                min_action=self.hoo.min_value,
                max_action=self.hoo.max_value
            )
            child.immediate_reward = r
            self.children[act_key] = child

        # 3) recurse, then credit HOO with reward-to-go through this action
        val = child.selection(depth + 1, env, max_depth)
        total = child.immediate_reward + val
        self.hoo.update(total)
        return total

    def delete_subtree(self, node):
        """Detach ``node`` and all of its descendants from one another.

        Every node in the subtree has its ``children`` mapping emptied and its
        ``parent`` link cleared.  This breaks the parent/child reference
        cycles so discarded nodes (and the snapshots they hold) can be
        reclaimed as soon as the last outside reference is dropped, instead of
        waiting for the cyclic garbage collector.

        The entry that points at ``node`` in its former parent's ``children``
        dict is not removed here; the caller is expected to delete it.

        Args:
            node: Root of the subtree to discard.
        """
        # Iterative walk: avoids recursion-depth limits on deep trees.
        pending = [node]
        while pending:
            current = pending.pop()
            pending.extend(current.children.values())
            current.children.clear()
            current.parent = None

# Rollout-based version for high-dimensional environments
class NodeWithRollout(Node):
    """Node variant that estimates leaf values with a random rollout.

    When a simulation hits the depth limit at a non-terminal node and the
    action space has more than 8 dimensions, the value of that node is
    estimated by a short uniformly-random rollout instead of being 0.

    NOTE(review): the main script selects this class for ``dim > 6`` while
    rollouts only trigger for ``dim > 8``; for 7- and 8-dimensional action
    spaces this class therefore behaves like ``Node``.  Confirm which
    threshold is intended.
    """

    def __init__(self, snapshot, obs, is_done, parent, dim, min_action, max_action):
        """Build the node as ``Node`` does and fix the rollout length."""
        super().__init__(snapshot, obs, is_done, parent, dim, min_action, max_action)
        # At most 20 random steps, and never more than half the global depth cap.
        self.rollout_depth = min(20, MAX_MCTS_DEPTH // 2)

    def rollout(self, env, max_depth):
        """Play uniformly random actions from this state and return the result.

        Args:
            env: Planning environment; it is reset to this node's snapshot
                via ``load_snapshot`` and then advanced with ``step``.
            max_depth: Maximum number of steps; further capped by
                ``self.rollout_depth``.  Non-positive values yield 0.0.

        Returns:
            Discounted sum of the rewards seen during the rollout
            (0.0 for terminal nodes).
        """
        if self.is_done:
            return 0.0

        env.load_snapshot(self.snapshot)
        total = 0.0
        discount_factor = 1.0

        for _ in range(min(max_depth, self.rollout_depth)):
            # Uniform sample inside the action bounds used by this node's HOO.
            action = np.random.uniform(
                low=self.hoo.min_value,
                high=self.hoo.max_value,
                size=self.dim
            ).astype(np.float32)

            _, r, done, _ = env.step(action)
            total += r * discount_factor
            discount_factor *= discount

            if done:
                break

        return total

    def selection(self, depth, env, max_depth):
        """Run one simulation from this node and return its sampled return.

        Same procedure as ``Node.selection`` except that children are
        ``NodeWithRollout`` instances and that a non-terminal node reached at
        the depth limit is evaluated with ``rollout`` when ``dim > 8``.

        Args:
            depth: Depth of this node relative to the current root.
            env: Planning environment (``get_snapshot``, ``get_result``,
                ``load_snapshot`` and ``step`` are used).
            max_depth: Depth at which tree descent stops.

        Returns:
            Sum of rewards collected below this node, including the rollout
            estimate at the cut-off when one is performed.
        """
        # Terminal states are worth nothing further.
        if self.is_done:
            return 0.0

        if depth >= max_depth:
            if self.dim > 8:  # Use rollout for high-dimensional environments
                # The horizon must be the rollout budget.  The remaining tree
                # depth (max_depth - depth) is <= 0 here, which would make the
                # rollout loop run zero times and always return 0.0.
                return self.rollout(env, self.rollout_depth)
            return 0.0

        # 1) select an action from HOO and cast it to float32
        action = self.hoo.select_action().astype(np.float32)
        act_key = tuple(action)

        # 2) reuse the child for this exact action, or expand a new one
        child = self.children.get(act_key)
        if child is None:
            if self.snapshot is None:
                # Only expected for a root that was built without a snapshot:
                # fall back to the environment's current state.
                self.snapshot = env.get_snapshot()

            snapshot, obs, r, done, _ = env.get_result(self.snapshot, action)
            child = NodeWithRollout(
                snapshot=snapshot,
                obs=obs,
                is_done=done,
                parent=self,
                dim=self.dim,
                min_action=self.hoo.min_value,
                max_action=self.hoo.max_value
            )
            child.immediate_reward = r
            self.children[act_key] = child

        # 3) recurse, then credit HOO with reward-to-go through this action
        val = child.selection(depth + 1, env, max_depth)
        total = child.immediate_reward + val
        self.hoo.update(total)
        return total

###########################################################################
# 4) Main experiment
###########################################################################
if __name__ == "__main__":
    # Experiment driver.  For every environment and planning budget it runs
    # `num_seeds` episodes of plan -> act -> re-root -> re-plan, and appends
    # the mean discounted return (+/- two population std-devs) to a text file.
    results_filename = "hoot_results.txt"

    # Per-environment planning setup: (dim, max_depth, min_action, max_action).
    # A bound of None means "read it from an environment instance".
    planning_settings = {
        "Continuous-CartPole-v0": (1, 50, None, None),
        "StochasticPendulum-v0": (1, 50, -2.0, 2.0),
        "StochasticMountainCarContinuous-v0": (1, 50, -1.0, 1.0),
        "StochasticContinuousAcrobot-v0": (1, 50, -1.0, 1.0),
        "ImprovedHopper-v0": (3, 100, -1.0, 1.0),
        "ImprovedWalker2d-v0": (6, 100, -1.0, 1.0),
        "ImprovedAnt-v0": (8, 100, -1.0, 1.0),
    }
    fallback_settings = (1, 50, -1.0, 1.0)

    # The context manager closes the file even if an episode raises.  The
    # encoding is explicit because the summary line contains "±".
    with open(results_filename, "a", encoding="utf-8") as f_out:
        for envname in env_names:
            # A) noise config / constructor kwargs for this environment
            stoch_kwargs = ENV_NOISE_CONFIG.get(envname, {})

            # B) set dimension + action range + search depth
            dim, max_depth, min_action, max_action = planning_settings.get(
                envname, fallback_settings
            )
            if min_action is None or max_action is None:
                # Build a throw-away instance only when its bounds are needed,
                # and close it again so it does not leak.
                base_env = gym.make(envname, **stoch_kwargs).env
                try:
                    min_action = base_env.min_action
                    max_action = base_env.max_action
                finally:
                    base_env.close()

            print(f"\nEnvironment: {envname}")
            print(f"Action dimension: {dim}")
            print(f"Max depth: {max_depth}")

            # Print HOO parameters
            rho, nu = get_hoo_params(dim)
            print(f"HOO parameters - rho: {rho:.4f}, nu: {nu:.2f}")

            # Choose node type based on dimensionality
            use_rollout = dim > 6
            NodeClass = NodeWithRollout if use_rollout else Node
            print(f"Using {'rollout-enhanced' if use_rollout else 'standard'} HOOT algorithm")

            # C) Wrap environment in SnapshotEnv
            planning_env = SnapshotEnv(gym.make(envname, **stoch_kwargs).env)
            root_obs_ori = planning_env.reset()
            root_snapshot_ori = planning_env.get_snapshot()

            # We'll loop over the iteration counts in samples_to_use
            for ITERATIONS in samples_to_use:
                seed_returns = []
                for seed_i in range(num_seeds):
                    random.seed(seed_i)
                    np.random.seed(seed_i)

                    # Every seed starts from the same initial state.
                    root_obs = copy.copy(root_obs_ori)
                    root_snapshot = copy.copy(root_snapshot_ori)

                    root = NodeClass(
                        snapshot=root_snapshot,
                        obs=root_obs,
                        is_done=False,
                        parent=None,
                        dim=dim,
                        min_action=min_action,
                        max_action=max_action
                    )

                    # plan
                    for _ in range(ITERATIONS):
                        root.selection(depth=0, env=planning_env, max_depth=max_depth)

                    # test: the snapshot was produced in-process by
                    # planning_env.get_snapshot(), so unpickling it is safe.
                    test_env = pickle.loads(root_snapshot)
                    total_reward = 0.0
                    current_discount = 1.0

                    try:
                        for _step in range(TEST_ITERATIONS):
                            # pick best action from root's HOO
                            best_action = root.hoo.get_point().astype(np.float32)
                            best_key = tuple(best_action)

                            _obs, r, done, _info = test_env.step(best_action)
                            total_reward += r * current_discount
                            current_discount *= discount

                            if done:
                                break

                            # prune every child except the executed action
                            for akey in list(root.children):
                                if akey != best_key:
                                    root.delete_subtree(root.children[akey])
                                    del root.children[akey]

                            # re-root; expand the child first if planning
                            # never tried this exact action
                            if best_key not in root.children:
                                snap2, obs, r_new, done_new, _ = planning_env.get_result(
                                    root.snapshot, best_action
                                )
                                cnode = NodeClass(
                                    snapshot=snap2,
                                    obs=obs,
                                    is_done=done_new,
                                    parent=None,
                                    dim=dim,
                                    min_action=min_action,
                                    max_action=max_action
                                )
                                cnode.immediate_reward = r_new
                                root.children[best_key] = cnode
                            root = root.children[best_key]
                            root.parent = None

                            # re-plan from the new root
                            # NOTE(review): for dim > 6 the floor of 100 makes
                            # this larger than ITERATIONS whenever
                            # ITERATIONS < 100 - confirm that is intended.
                            plan_iterations = ITERATIONS if dim <= 6 else max(ITERATIONS // 2, 100)
                            for _ in range(plan_iterations):
                                root.selection(depth=0, env=planning_env, max_depth=max_depth)
                    finally:
                        # Closed exactly once, whether the episode finished,
                        # ran out of steps, or raised.
                        test_env.close()

                    seed_returns.append(total_reward)

                mean_return = statistics.mean(seed_returns)
                std_return = statistics.pstdev(seed_returns)
                interval = 2.0 * std_return

                msg = (f"Env={envname}, ITER={ITERATIONS}: "
                       f"Mean={mean_return:.3f} ± {interval:.3f} "
                       f"(over {num_seeds} seeds)")
                print(msg)
                f_out.write(msg + "\n")
                # Flush per result line so partial runs still leave output.
                f_out.flush()

    print("Done! Results saved to", results_filename)
